import logging
import threading
from functools import wraps

import pymysql
from DBUtils.PooledDB import PooledDB

from src import common
from src import config


def cursor(f):
    """Decorate ``f(cur)`` so it runs on a cursor borrowed from ``self.pool``.

    The returned ``inner(self)`` takes a connection from ``self.pool``, calls
    ``f`` with a fresh cursor and commits when ``f`` succeeds.  If ``f`` (or
    the commit) raises, the transaction is rolled back and the exception is
    re-raised.  The cursor and the connection are closed in every case; for
    a DBUtils pooled connection, ``close()`` hands it back to the pool.

    Returns whatever ``f`` returned.
    """
    @wraps(f)
    def inner(self):
        conn = self.pool.connection()
        try:
            cur = conn.cursor()
            try:
                res = f(cur)
                conn.commit()
            except Exception:
                # Do not hand a connection with a half-finished transaction
                # (and the locks it holds) back to the pool.
                conn.rollback()
                raise
            finally:
                cur.close()
        finally:
            conn.close()
        return res
    return inner


class Database:
    """Small CRUD helper over a DBUtils ``PooledDB`` of pymysql connections.

    Every query method borrows one pooled connection through the module-level
    ``cursor`` decorator, runs a single statement, commits and closes the
    connection again.  Rows come back as dicts because the pool is configured
    with ``pymysql.cursors.DictCursor``.

    Values read from ``args`` are sent to pymysql as bound parameters, so
    quotes or other special characters in them can neither break nor alter a
    statement.  Table names, column names, ``filter_keys`` and join
    ``condition``/``conditions`` strings are still pasted into the SQL text
    verbatim: they must come from trusted code, never from request data.
    """

    # Debug logging of executed SQL text (bound values are not logged).
    _log = logging.getLogger(__name__)

    def __init__(self, app):
        """Create the connection pool from ``app.config``.

        Reads ``SQL_ADDRESS``, ``SQL_USER``, ``SQL_PASSWORD``,
        ``SQL_DATABASE`` and ``SQL_PORT``.
        """
        self.pool = PooledDB(
            creator=pymysql,   # DB-API module used to open connections
            maxconnections=6,  # max connections the pool allows; 0/None = unlimited
            mincached=2,  # idle connections opened at start-up; 0 = none
            maxcached=5,  # max idle connections kept in the pool; 0/None = unlimited
            # Max shared connections (0/None = all connections dedicated).
            # Has no effect here: pymysql's threadsafety is 1, so PooledDB
            # never shares a connection between threads.
            maxshared=3,
            blocking=True,  # pool exhausted: True = wait for a connection, False = raise
            maxusage=None,  # max times one connection is reused; None = unlimited
            setsession=[],  # SQL run when a session starts, e.g. ["set time zone ..."]
            # When to ping the server: 0/None = never, 1 = when a connection
            # is requested from the pool, 2 = when a cursor is created,
            # 4 = when a query is executed, 7 = always.
            # NOTE(review): with 0, a connection the server has dropped (e.g.
            # after wait_timeout) is only noticed when a query on it fails.
            ping=0,
            host=app.config['SQL_ADDRESS'],
            user=app.config['SQL_USER'],
            password=app.config['SQL_PASSWORD'],
            db=app.config['SQL_DATABASE'],
            port=app.config['SQL_PORT'],
            cursorclass=pymysql.cursors.DictCursor,
            charset='utf8'
        )
        # Not used by any method of this class; presumably used by callers.
        self.lock = threading.Lock()

    @staticmethod
    def __escape_percent(fragment):
        """Double every ``%`` in trusted SQL text used with bound parameters.

        When parameters are passed, pymysql interpolates them with Python
        ``%`` formatting, so a literal ``%`` in an identifier or expression
        (e.g. ``DATE_FORMAT(t, '%Y')`` in ``filter_keys``) must become ``%%``.
        """
        return fragment.replace('%', '%%')

    def __run(self, _sql, params=None, fetch=None):
        """Execute one statement on a pooled connection (via ``cursor``).

        Args:
            _sql: statement text.  When ``params`` is not None it must use
                ``%s`` placeholders and ``%%`` for a literal percent sign.
            params: sequence of values bound to the placeholders, or None to
                send ``_sql`` to the server without any ``%`` formatting.
            fetch: ``'one'`` returns ``cur.fetchone()``, ``'all'`` returns
                ``cur.fetchall()``, None returns ``{}``.
        """
        self._log.debug("SQL: %s", _sql)

        @cursor
        def func(cur):
            cur.execute(_sql, params)
            if fetch == 'one':
                return cur.fetchone()
            if fetch == 'all':
                return cur.fetchall()
            return {}
        return func(self)

    def __construct_list_sql(self, table_name, args, filter_keys, condition_keys):
        """Build ``SELECT <filter_keys or *> FROM <table> [WHERE k=%s and ...]``.

        Returns ``(sql, params)``.  ``params`` holds ``args[k]`` for each key
        in ``condition_keys``, in placeholder order.
        """
        condition_keys = condition_keys or ()
        columns = ",".join(filter_keys or ()) or '*'
        _sql = "SELECT {} FROM {} ".format(
            self.__escape_percent(columns), self.__escape_percent(table_name))
        if condition_keys:
            _sql += "WHERE " + " and ".join(
                self.__escape_percent(key) + "=%s" for key in condition_keys)
        params = tuple(args[key] for key in condition_keys)
        return _sql, params

    def __construct_insert_sql(self, table_name, args):
        """Build ``INSERT INTO <table>(c1,c2,...)VALUES(%s,%s,...)``.

        Columns and values follow the iteration order of ``args``.
        Returns ``(sql, params)``.
        """
        columns = list(args)
        _sql = "INSERT INTO {}({})VALUES({})".format(
            self.__escape_percent(table_name),
            self.__escape_percent(",".join(columns)),
            ",".join(["%s"] * len(columns)))
        return _sql, tuple(args[column] for column in columns)

    def __construct_update_sql(self, table_name, args, cur_key="id"):
        """Build ``UPDATE <table> SET c1=%s,... WHERE <cur_key>=%s``.

        Every key of ``args`` is assigned (``cur_key`` included); the row is
        selected by ``args[cur_key]``.  Returns ``(sql, params)``.
        """
        columns = list(args)
        assignments = ",".join(
            self.__escape_percent(column) + "=%s" for column in columns)
        _sql = "UPDATE {} SET {} WHERE {}=%s".format(
            self.__escape_percent(table_name), assignments,
            self.__escape_percent(cur_key))
        params = tuple(args[column] for column in columns) + (args[cur_key],)
        return _sql, params

    def __construct_delete_sql(self, table_name, args, filter_key='id', cur_key="ids"):
        """Build ``DELETE FROM <table> WHERE <filter_key> in (%s,...)``.

        ``args[cur_key]`` is a non-string iterable of key values; one
        placeholder is emitted per value.  Returns ``(sql, params)``.  An
        empty iterable yields ``in ()``, which is invalid SQL and must not be
        executed.

        Raises:
            TypeError: if ``args[cur_key]`` is a string; iterating it would
                silently turn each character into an id.
        """
        ids = args[cur_key]
        if isinstance(ids, (str, bytes)):
            raise TypeError(
                "args[%r] must be a sequence of ids, not a string" % (cur_key,))
        ids = tuple(ids)
        _sql = "DELETE FROM {} WHERE {} in ({})".format(
            self.__escape_percent(table_name),
            self.__escape_percent(filter_key),
            ",".join(["%s"] * len(ids)))
        return _sql, ids

    def __construct_left_join_sql(self, table1_name, table2_name, filter_keys, conditions):
        """Build ``SELECT ... FROM t1 left join t2 on <conditions>``.

        ``conditions`` is raw SQL appended verbatim; the statement is run
        without bound parameters.
        """
        _sql = "SELECT %s FROM %s left join %s on " % (
            ",".join(filter_keys or ()) or '*', table1_name, table2_name)
        return _sql + conditions

    def __construct_inner_join_sql(self, table_names=(), filter_keys=(), conditions=()):
        """Build ``SELECT ... FROM t1,t2,... [WHERE c1 and c2 ...]``.

        ``conditions`` are raw SQL strings joined with ``and``; the statement
        is run without bound parameters.
        """
        _sql = 'SELECT %s FROM %s ' % (
            ','.join(filter_keys or ()) or '*', ','.join(table_names or ()))
        if conditions:
            _sql += "WHERE " + ' and '.join(conditions)
        return _sql

    def get(self, table_name, args, cur_key='id'):
        """Fetch one record: the first row whose ``cur_key`` equals ``args[cur_key]``."""
        _sql = "SELECT * FROM {} WHERE {}=%s".format(
            self.__escape_percent(table_name), self.__escape_percent(cur_key))
        return self.__run(_sql, (args[cur_key],), fetch='one')

    def get_by(self, table_name, args, filter_keys=(), condition_keys=('id',)):
        """Fetch the first row matching ``key = args[key]`` for every condition key.

        Only ``filter_keys`` columns are selected (all columns when empty).
        """
        _sql, params = self.__construct_list_sql(
            table_name, args, filter_keys, condition_keys)
        return self.__run(_sql, params, fetch='one')

    def insert(self, table_name, args):
        """Insert one row whose columns/values are the items of ``args``; returns ``{}``."""
        _sql, params = self.__construct_insert_sql(table_name, args)
        return self.__run(_sql, params)

    def list_by(self, table_name, args, filter_keys=(), condition_keys=('id',)):
        """Fetch all rows matching ``key = args[key]`` for every condition key.

        Only ``filter_keys`` columns are selected (all columns when empty).
        """
        _sql, params = self.__construct_list_sql(
            table_name, args, filter_keys, condition_keys)
        return self.__run(_sql, params, fetch='all')

    def list_left_join_by(self, left_table_name, right_table_name, filter_keys=(), condition=""):
        """Fetch all rows of ``left left join right on <condition>`` (raw SQL condition)."""
        _sql = self.__construct_left_join_sql(
            left_table_name, right_table_name, filter_keys, condition)
        return self.__run(_sql, fetch='all')

    def list_inner_join_by(self, table_names=(), filter_keys=(), conditions=()):
        """Fetch all rows of the implicit join of ``table_names`` filtered by raw ``conditions``."""
        _sql = self.__construct_inner_join_sql(
            table_names, filter_keys, conditions)
        return self.__run(_sql, fetch='all')

    def list_all(self, table_name):
        """Fetch every row of ``table_name``."""
        return self.__run("SELECT * FROM %s" % (table_name,), fetch='all')

    def list_all_filter(self, table_name, filter_keys=()):
        """Fetch the ``filter_keys`` columns (all columns when empty) of every row."""
        _sql = "SELECT %s FROM %s" % (
            ",".join(filter_keys or ()) or '*', table_name)
        return self.__run(_sql, fetch='all')

    def list_like(self, table_name, args, cur_key):
        """Fuzzy query: rows whose ``cur_key`` contains ``args[cur_key]`` (LIKE '%v%').

        ``%`` and ``_`` inside the value still act as LIKE wildcards.
        """
        _sql = "SELECT * FROM {} where {} LIKE %s".format(
            self.__escape_percent(table_name), self.__escape_percent(cur_key))
        return self.__run(_sql, ("%{}%".format(args[cur_key]),), fetch='all')

    def list_cur(self, table_name, args, cur_key):
        """Exact query: all rows whose ``cur_key`` equals ``args[cur_key]``."""
        _sql = "SELECT * FROM {} where {}=%s".format(
            self.__escape_percent(table_name), self.__escape_percent(cur_key))
        return self.__run(_sql, (args[cur_key],), fetch='all')

    def list_cur_filter(self, table_name, args, cur_key, filter_keys=()):
        """Fuzzy query returning only ``filter_keys`` columns (all when empty).

        Despite the "exact" wording used for this method before, it matches
        with ``LIKE '%v%'``; that behaviour is kept for existing callers.
        """
        columns = ','.join(filter_keys or ()) or '*'
        _sql = "SELECT {} FROM {} where {} LIKE %s".format(
            self.__escape_percent(columns), self.__escape_percent(table_name),
            self.__escape_percent(cur_key))
        return self.__run(_sql, ("%{}%".format(args[cur_key]),), fetch='all')

    def delete(self, table_name, args, cur_key='id'):
        """Delete rows whose ``cur_key`` equals ``args[cur_key]``; returns ``{}``."""
        _sql = "DELETE FROM {} WHERE {}=%s".format(
            self.__escape_percent(table_name), self.__escape_percent(cur_key))
        return self.__run(_sql, (args[cur_key],))

    def multi_delete(self, table_name, args, filter_key='id', cur_key="ids"):
        """Delete rows whose ``filter_key`` is in ``args[cur_key]``; returns ``{}``.

        An empty id list deletes nothing and does not touch the database.
        """
        _sql, params = self.__construct_delete_sql(
            table_name, args, filter_key=filter_key, cur_key=cur_key)
        if not params:
            # "in ()" is a syntax error; there is nothing to delete anyway.
            return {}
        return self.__run(_sql, params)

    def update(self, table_name, args, cur_key='id'):
        """Set every column in ``args`` on the row selected by ``args[cur_key]``; returns ``{}``."""
        _sql, params = self.__construct_update_sql(
            table_name, args, cur_key=cur_key)
        return self.__run(_sql, params)

    def customize_sql(self, _sql, params=None):
        """Run arbitrary SQL and return all fetched rows.

        ``_sql`` must come from trusted code.  Pass untrusted values through
        ``params`` (``%s`` placeholders, ``%%`` for a literal ``%``); with the
        default ``params=None`` the text is sent without ``%`` formatting.
        """
        return self.__run(_sql, params, fetch='all')
